## Calculate features for nuclei and generate .pt files for each graph
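## Each nucleus gets 8 contour features + 4 GLCM texture features, concatenated with the CPC encoder embedding (an earlier note here said "15 features in total")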

In [1]:
import re, os
import cv2
import math
import random
import torch
import resnet
import skimage.feature
import pdb
from PIL import Image
from pyflann import *
from torch_geometric.data import Data
from collections import OrderedDict

import networkx as nx
import numpy as np
import pandas as pd
import torchvision.transforms.functional as F
import torch_geometric.data as data
import torch_geometric.utils as utils
import pdb
import torch_geometric

In [2]:
from model import CPC_model

# Build the CPC model and keep only its encoder, placed on the first GPU.
device = torch.device('cuda:{}'.format('0'))
model = CPC_model(1024, 256)
encoder = model.encoder.to(device)

# Restore the pretrained encoder weights.
# map_location makes torch.load place the stored tensors on `device` instead of
# on whatever device they were saved from, so a checkpoint written on another
# GPU index still loads here.
ckpt_dir = './pretrained_models/cpc.pt'
ckpt = torch.load(ckpt_dir, map_location=device)
# NOTE(review): the encoder is never switched to eval() in this notebook; if it
# contains BatchNorm/Dropout layers the extracted features depend on that mode.
# Left unchanged so the features stay comparable with previously generated ones.
encoder.load_state_dict(ckpt['encoder_state_dict'])

<All keys matched successfully>

In [4]:
def from_networkx(G):
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.

    Node labels are first replaced by consecutive integers ``0..N-1`` (in node
    iteration order) so that the values stored in ``edge_index`` are valid row
    indices into the stacked node attributes. Without this step a graph whose
    labels are not already ``0..N-1`` -- e.g. a connected-component subgraph of
    a larger graph -- produces an ``edge_index`` that points at the wrong rows
    or past the end of ``x``.

    The attribute names are taken from the first node and the first edge; every
    node/edge attribute is stacked into one tensor per name.

    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph with at least
            one node and one edge.

    Returns:
        torch_geometric.data.Data: one tensor per node/edge attribute name,
        plus ``edge_index`` and ``num_nodes``.

    Raises:
        IndexError: if the graph has no nodes or no edges.
    """

    # Relabel to 0..N-1; this is a no-op for graphs that are already labelled so.
    G = nx.convert_node_labels_to_integers(G)
    # Undirected graphs are expanded so each edge appears in both directions.
    G = G.to_directed() if not nx.is_directed(G) else G
    edge_index = torch.tensor(list(G.edges)).t().contiguous()

    # Attribute names come from the first node and the first edge only.
    keys = []
    keys += list(list(G.nodes(data=True))[0][1].keys())
    keys += list(list(G.edges(data=True))[0][2].keys())
    data = {key: [] for key in keys}

    for _, feat_dict in G.nodes(data=True):
        for key, value in feat_dict.items():
            data[key].append(value)

    for _, _, feat_dict in G.edges(data=True):
        for key, value in feat_dict.items():
            data[key].append(value)

    for key, item in data.items():
        data[key] = torch.tensor(item)

    data['edge_index'] = edge_index
    data = torch_geometric.data.Data.from_dict(data)
    data.num_nodes = G.number_of_nodes()

    return data

In [6]:
from torchvision import transforms
import itertools

def get_cell_image(img, cx, cy, size=512):
    """Return the 64 x 64 crop of ``img`` centred on ``(cx, cy)``.

    Each centre coordinate is first clamped to ``[32, size - 32]`` so that the
    window stays inside an image whose side length is ``size``.

    Args:
        img: array indexed as ``img[row, col, channel]``.
        cx: column of the window centre.
        cy: row of the window centre.
        size: side length used for clamping both coordinates.

    Returns:
        The slice ``img[cy-32:cy+32, cx-32:cx+32, :]`` at the clamped centre.
    """
    half = 32
    upper = size - half

    def _clamp(coord):
        # Lower bound is checked first, exactly like the nested conditional.
        if coord < half:
            return half
        if coord > upper:
            return upper
        return coord

    col, row = _clamp(cx), _clamp(cy)
    return img[row - half:row + half, col - half:col + half, :]

def get_cpc_features(cell):
    """Run one cell crop through the module-level ``encoder``.

    The crop is converted to a tensor, normalised with mean 0.5 / std 0.5 per
    channel, given a leading batch dimension of 1 and moved to ``cuda:0``.

    The forward pass runs under ``torch.no_grad()``: only the output values are
    used, so no autograd graph needs to be recorded for each cell.

    Args:
        cell: image crop accepted by ``transforms.ToTensor`` (the caller passes
            the crop returned by ``get_cell_image``).

    Returns:
        numpy array holding the encoder output for the single batch element.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    batch = transform(cell).unsqueeze(0)
    device = torch.device('cuda:{}'.format('0'))
    with torch.no_grad():
        encoded = encoder(batch.to(device))
    feats = encoded.cpu().detach().numpy()[0]
    return feats

def get_cell_features(img, contour):
    """Compute the feature vector of one nucleus.

    The vector is the concatenation, in this order, of
      1. contour features: short_axis, long_axis, angle, area, arc_length,
         eccentricity, roundness, solidity;
      2. GLCM features of the grey-level crop: dissimilarity, homogeneity,
         energy, ASM;
      3. the CPC encoder output for the RGB crop.

    Args:
        img: image the contour was extracted from; converted with
            ``cv2.COLOR_RGB2GRAY``, so presumably RGB -- confirm with caller.
        contour: one OpenCV contour (must be accepted by ``cv2.fitEllipse``).

    Returns:
        tuple ``(features, cx, cy)`` where ``features`` is a 1-D float64 numpy
        array and ``cx, cy`` is the integer ellipse centre.

    Raises:
        ZeroDivisionError: if the contour or its convex hull has zero area.
    """

    # Ellipse fit gives the centre, the two axis lengths and the rotation angle.
    (cx, cy), (short_axis, long_axis), angle = cv2.fitEllipse(contour)
    cx, cy = int(cx), int(cy)

    # Get a 64 x 64 center crop over each cell
    img_cell = get_cell_image(img, cx, cy)

    # Grey-level copy, padded at the bottom/right up to 64 x 64 in case the
    # crop came out smaller.
    grey_region = cv2.cvtColor(img_cell, cv2.COLOR_RGB2GRAY)
    pad_rows = 64 - grey_region.shape[0]
    pad_cols = 64 - grey_region.shape[1]
    img_cell_grey = np.pad(grey_region, [(0, pad_rows), (0, pad_cols)], mode='reflect')

    # 1. Generating contour features
    # NOTE(review): assumes short_axis <= long_axis; otherwise math.sqrt gets a
    # negative argument and raises ValueError.
    eccentricity = math.sqrt(1 - (short_axis / long_axis) ** 2)
    convex_hull = cv2.convexHull(contour)
    area, hull_area = cv2.contourArea(contour), cv2.contourArea(convex_hull)
    solidity = float(area) / hull_area
    arc_length = cv2.arcLength(contour, True)
    # Ratio of the radius implied by the perimeter to the radius implied by the area.
    roundness = (arc_length / (2 * math.pi)) / (math.sqrt(area / math.pi))

    # 2. Generating GLCM features
    # scikit-image renamed grey* -> gray*; prefer the new names and fall back to
    # the old ones so the cell runs on both older and newer releases.
    glcm_fn = getattr(skimage.feature, 'graycomatrix', None) or skimage.feature.greycomatrix
    props_fn = getattr(skimage.feature, 'graycoprops', None) or skimage.feature.greycoprops
    out_matrix = glcm_fn(img_cell_grey, [1], [0])
    dissimilarity, homogeneity, energy, ASM = (
        props_fn(out_matrix, prop)[0][0]
        for prop in ('dissimilarity', 'homogeneity', 'energy', 'ASM')
    )

    # 3. Generating CPC features
    cpc_feats = get_cpc_features(img_cell)

    # Concatenate + Return all features
    x = [[short_axis, long_axis, angle, area, arc_length, eccentricity, roundness, solidity],
         [dissimilarity, homogeneity, energy, ASM],
         cpc_feats]

    return np.array(list(itertools.chain(*x)), dtype=np.float64), cx, cy


def seg2graph(img, contours):
    """Build an edge-less graph with one node per usable nucleus contour.

    Contours with 5 or fewer points are discarded. If fewer than 6 contours
    remain, ``None`` is returned (same threshold as before, but now decided
    up front: the previous version read the loop variable after the loop, which
    raised ``UnboundLocalError`` when no contour survived the filter, and it
    computed features that were then thrown away).

    Args:
        img: image passed through to ``get_cell_features``.
        contours: iterable of OpenCV contours (arrays with a ``shape``).

    Returns:
        ``networkx.Graph`` whose node ``v`` carries ``centroid=[cx, cy]`` and
        ``x=features``, or ``None`` when there are too few contours.
    """
    contours = [c for c in contours if c.shape[0] > 5]

    # Original condition was "last index < 5", i.e. at most 5 contours.
    if len(contours) < 6:
        return None

    G = nx.Graph()
    for v, contour in enumerate(contours):
        features, cx, cy = get_cell_features(img, contour)
        G.add_node(v, centroid=[cx, cy], x=features)

    return G

In [1]:
# Locations of the example images and their segmentation masks.
data_dir = "./example_data/"
img_dir, seg_dir = (os.path.join(data_dir, sub) for sub in ('imgs', 'segs'))

# The three example ROIs processed by this notebook.
roi1 = 'TCGA-06-0174-01Z-00-DX3.23b6e12e-dfc1-4c6f-903e-170038a0e055_1.png'
roi2 = 'TCGA-HT-7470-01Z-00-DX4.204D0CF2-A22E-4428-8E8C-572432B86280_1.png'
roi3 = 'TCGA-26-1442-01Z-00-DX1.FD8D4EB7-AD5E-49E8-BD0B-6CDDEA8DDB35_1.png'

# Sanity check: every ROI must have a segmentation mask on disk.
_seg_files = os.listdir(seg_dir)
assert roi1 in _seg_files
assert roi2 in _seg_files
assert roi3 in _seg_files

In [26]:
# For every ROI: extract nuclei contours from the segmentation mask, build a
# nearest-neighbour graph over the nuclei, save an overlay image of the graph
# and save the graph as a torch_geometric .pt file.
save_dir = data_dir
pt_dir = os.path.join(save_dir, 'pts')
graph_dir = os.path.join(save_dir, 'graphs')
fail_list = []  # ROIs for which seg2graph returned None

# Both output folders must exist before the first save call.
os.makedirs(pt_dir, exist_ok=True)
os.makedirs(graph_dir, exist_ok=True)

from tqdm import tqdm

for img_fname in tqdm([roi1, roi2, roi3]):

    img = np.array(Image.open(os.path.join(img_dir, img_fname)))
    seg = np.array(Image.open(os.path.join(seg_dir, img_fname)))
    ret, binary = cv2.threshold(seg, 127, 255, cv2.THRESH_BINARY)
    # OpenCV 3 returns (image, contours, hierarchy), OpenCV 4 returns
    # (contours, hierarchy); taking the last two values works with both.
    contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
    if len(contours) < 1: continue

    G = seg2graph(img, contours)

    if G is None:
        fail_list.append(img_fname)
        continue

    # One row per node, in node order (node v of seg2graph is row v).
    cell_centroids = np.array(
        [attrib['centroid'] for _, attrib in G.nodes(data=True)]
    ).astype(np.float64)
    dataset = cell_centroids

    # One batched query: row i of `results`/`dists` holds the 5 neighbours
    # found for centroid i. (Previously a fresh FLANN object and a separate
    # query were issued for every single node.)
    flann = FLANN()
    results, dists = flann.nn(dataset, dataset, num_neighbors=5, algorithm='kmeans',
                              branching=32, iterations=100, checks=16)

    for idx, nbr_ids, nbr_dists in zip(list(G.nodes), results, dists):
        # Entry 0 is skipped as before -- presumably the query point itself.
        for nbr, dist in zip(nbr_ids[1:], nbr_dists[1:]):
            G.add_edge(idx, int(nbr), weight=dist)

    # Keep only the largest connected component.
    G = G.subgraph(max(nx.connected_components(G), key=len))

    # Overlay: contours in (0, 255, 0), graph edges in (0, 0, 255).
    cv2.drawContours(img, contours, -1, (0,255,0), 2)

    for n, nbr in G.edges():
        cv2.line(img, tuple(G.nodes[n]['centroid']),  tuple(G.nodes[nbr]['centroid']), (0, 0, 255), 2)

    Image.fromarray(img).save(os.path.join(graph_dir, img_fname))

    G = from_networkx(G)

    # Cast tensors to the dtypes stored in the .pt file: edge weights become a
    # LongTensor column (fractional parts are dropped by the cast).
    edge_attr_long = (G.weight.unsqueeze(1)).type(torch.LongTensor)
    G.edge_attr = edge_attr_long

    edge_index_long = G['edge_index'].type(torch.LongTensor)
    G.edge_index = edge_index_long

    x_float = G['x'].type(torch.FloatTensor)
    G.x = x_float

    # The raw weight now lives in edge_attr; clear the helper attributes.
    G['weight'] = None
    G['nn'] = None
    torch.save(G, os.path.join(pt_dir, os.path.splitext(img_fname)[0] + '.pt'))

100%|██████████| 3/3 [00:13<00:00,  4.41s/it]
